# Comprehensive TUP CATE Analysis with Fixed Gamma and Heavy Interval Threshold

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime
warnings.filterwarnings('ignore')

class TeeOutput:
    """File-like object that mirrors every write to the console and a log file.

    It is meant to be assigned to ``sys.stdout`` so that ``print`` output is
    captured in a file while still appearing on screen. It can also be used
    as a context manager, which closes the log file on exit.
    """

    def __init__(self, filename):
        # Remember whatever stdout is at construction time so output still
        # reaches the console after sys.stdout is replaced by this object.
        self.terminal = sys.stdout
        # Explicit UTF-8: the text printed in this file contains non-ASCII
        # symbols (e.g. ε, γ, ρ) that a platform-default codec may reject.
        self.log = open(filename, 'w', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the console and, while it is open, to the log."""
        self.terminal.write(message)
        # Tolerate writes after close() so a still-installed tee does not
        # turn every later print into a "write to closed file" error.
        if not self.log.closed:
            self.log.write(message)
            self.log.flush()  # Ensure immediate write to file
        return len(message)

    def flush(self):
        """Flush the console stream and the log file (if still open)."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; safe to call more than once."""
        self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions

class TUPCATEAllocator:
    """TUP CATE allocation algorithm with fixed gamma=0.5 and updated heavy interval threshold."""

    def __init__(self, epsilon=0.1, gamma=0.5, delta=0.05, heavy_multiplier=1.6, random_seed=42):
        self.epsilon = epsilon
        self.gamma = gamma
        self.rho = gamma * np.sqrt(epsilon)
        self.delta = delta
        self.heavy_multiplier = heavy_multiplier  # New parameter for heavy interval threshold
        self.random_seed = random_seed
        np.random.seed(random_seed)

        print(f"TUP CATE Allocation Algorithm")
        print(f"ε = {epsilon}")
        print(f"√ε = {np.sqrt(epsilon):.6f}")
        print(f"γ = {gamma}")
        print(f"ρ = γ√ε = {self.rho:.6f}")
        print(f"Heavy multiplier = {heavy_multiplier}x")
        print(f"δ = {delta}")
        print("="*60)

    def process_tup_data(self, df, outcome_col=None):
        """Process TUP dataset for analysis"""
        print(f"Processing TUP data with {len(df)} observations")
        print(f"Available columns: {len(df.columns)} columns")

        # DEBUG: Check what columns we actually have
        target_cols = [col for col in df.columns if 'target' in col.lower()]
        consumption_cols = [col for col in df.columns if 'consumption' in col.lower()]
        outcome_cols = [col for col in df.columns if 'outcome' in col.lower()]

        print(f"Columns with 'target': {target_cols}")
        print(f"Columns with 'consumption': {consumption_cols}")
        print(f"Columns with 'outcome': {outcome_cols}")

        df_processed = df.copy()

        # Check for required columns
        if 'treatment' not in df_processed.columns:
            raise ValueError("Missing required 'treatment' column")

        # FIRST PRIORITY: If outcome column already exists, use it directly
        if 'outcome' in df_processed.columns:
            print("Found existing 'outcome' column - using directly")
            # Keep the existing outcome column

            # Look for baseline consumption
            baseline_cols = [col for col in df.columns if 'pc_exp_month_bl' in col or ('bl' in col and 'exp' in col)]
            if baseline_cols:
                df_processed['baseline_consumption'] = df_processed[baseline_cols[0]]
                print(f"Using {baseline_cols[0]} as baseline consumption")
            else:
                print("Warning: No baseline consumption found, but proceeding with existing outcome")
                df_processed['baseline_consumption'] = 0  # Placeholder

        # SECOND PRIORITY: If target_column_consumption already exists (from preprocessing), use it
        elif 'target_column_consumption' in df_processed.columns:
            print("Found existing target_column_consumption - using as outcome")
            df_processed['outcome'] = df_processed['target_column_consumption']

            # Look for baseline consumption
            baseline_cols = [col for col in df.columns if 'pc_exp_month_bl' in col or ('bl' in col and 'exp' in col)]
            if baseline_cols:
                df_processed['baseline_consumption'] = df_processed[baseline_cols[0]]
                print(f"Using {baseline_cols[0]} as baseline consumption")
            else:
                print("Warning: No baseline consumption found, but proceeding with existing outcome")
                df_processed['baseline_consumption'] = 0  # Placeholder

        # THIRD PRIORITY: If outcome column is explicitly specified, use it
        elif outcome_col and outcome_col in df_processed.columns:
            print(f"Using provided outcome column: {outcome_col}")
            df_processed['outcome'] = df_processed[outcome_col]
            baseline_consumption_cols = [col for col in df.columns if 'bl' in col and any(x in col.lower() for x in ['exp', 'consumption', 'income'])]
            if baseline_consumption_cols:
                df_processed['baseline_consumption'] = df_processed[baseline_consumption_cols[0]]
            else:
                df_processed['baseline_consumption'] = 0  # Default if no baseline found

        # Final cleaning
        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['outcome', 'treatment'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} additional rows due to missing outcome/treatment")

        print(f"Final dataset: {final_size} households")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")
        print(f"Outcome statistics: mean={df_processed['outcome'].mean():.3f}, std={df_processed['outcome'].std():.3f}")

        if 'baseline_consumption' in df_processed.columns:
            print(f"Baseline consumption stats: mean={df_processed['baseline_consumption'].mean():.3f}, std={df_processed['baseline_consumption'].std():.3f}")

        return df_processed

    def create_baseline_poverty_groups(self, df, n_groups=30, min_size=6):
        """Create groups by baseline consumption quintiles/deciles."""
        print(f"Creating baseline poverty groups (target: {n_groups})")

        if 'baseline_consumption' not in df.columns:
            print("No baseline consumption found, using first available consumption measure")
            consumption_cols = [col for col in df.columns if 'pc_exp' in col and 'bl' in col]
            if consumption_cols:
                baseline_col = consumption_cols[0]
                df['baseline_consumption'] = df[baseline_col]
            else:
                print("No baseline consumption measures found")
                return []

        # Create consumption-based groups
        consumption = df['baseline_consumption'].fillna(df['baseline_consumption'].median())

        # Create percentile groups
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(consumption, percentiles)
        bins = np.digitize(consumption, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'poverty_level_{i}',
                    'indices': indices,
                    'type': 'baseline_poverty'
                })

        print(f"Created {len(groups)} baseline poverty groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_demographics_groups(self, df, min_size=6):
        """Create groups by demographic characteristics relevant to TUP."""
        print(f"Creating TUP demographics groups")

        # Look for demographic variables commonly in TUP datasets
        potential_features = []

        # Look for household head characteristics
        demo_patterns = ['female', 'gender', 'head', 'age', 'education', 'literate',
                        'married', 'widow', 'household_size', 'children', 'caste', 'religion']

        for pattern in demo_patterns:
            matching_cols = [col for col in df.columns if pattern in col.lower() and not col.startswith('gkt')]
            potential_features.extend(matching_cols)

        # Remove duplicates and check which ones have reasonable variation
        potential_features = list(set(potential_features))

        available_features = []
        for col in potential_features:
            if col in df.columns and df[col].notna().sum() > 0:
                # Check if it's not too sparse or too uniform
                unique_vals = df[col].nunique()
                if 2 <= unique_vals <= 10:  # Reasonable number of categories
                    available_features.append(col)

        if len(available_features) == 0:
            print("No suitable demographic variables found, creating simple binary splits")
            # Create simple splits based on baseline consumption
            if 'baseline_consumption' in df.columns:
                median_consumption = df['baseline_consumption'].median()
                df['consumption_above_median'] = (df['baseline_consumption'] > median_consumption).astype(int)
                available_features = ['consumption_above_median']
            else:
                return []

        print(f"Using demographic features: {available_features}")

        # Limit to top 3 features to avoid too many combinations
        if len(available_features) > 3:
            available_features = available_features[:3]

        # Remove rows with missing values in these features
        df_clean = df.dropna(subset=available_features)
        print(f"After removing missing values: {len(df_clean)}/{len(df)} households")

        if len(df_clean) == 0:
            return []

        # Get unique combinations
        groups = []
        unique_combinations = df_clean[available_features].drop_duplicates()
        print(f"Found {len(unique_combinations)} unique demographic combinations")

        for combo_idx, (idx, combo) in enumerate(unique_combinations.iterrows()):
            mask = pd.Series(True, index=df.index)
            combo_description = []

            for feature in available_features:
                mask = mask & (df[feature] == combo[feature])
                combo_description.append(f"{feature}={combo[feature]}")

            indices = df[mask].index.tolist()
            combo_id = "_".join(combo_description)

            if len(indices) >= min_size:
                groups.append({
                    'id': combo_id,
                    'indices': indices,
                    'type': 'demographics'
                })

        print(f"Created {len(groups)} demographic groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_village_groups(self, df, min_size=6):
        """Create groups by village-level characteristics from TUP survey data."""
        print(f"Creating village-based groups from survey indicators (min_size={min_size})")

        # Look for location-type indicators that can proxy for villages
        location_patterns = ['district', 'rural', 'urban', 'capital', 'area', 'upazila', 'thana']
        location_cols = []

        for pattern in location_patterns:
            matching_cols = [col for col in df.columns
                           if pattern in col.lower() and 'bl_' in col
                           and not any(x in col.lower() for x in ['gram flour', 'food', 'loan'])]
            location_cols.extend(matching_cols)

        location_cols = list(set(location_cols))  # Remove duplicates

        if location_cols:
            print(f"Found location indicator columns: {location_cols[:3]}...")  # Show first 3

            # Use the first location column to create groups
            location_col = location_cols[0]
            print(f"Using location indicator: {location_col}")

            # For binary indicator columns, group by the indicator value
            groups = []
            for location_value in df[location_col].unique():
                if pd.isna(location_value):
                    continue

                indices = df[df[location_col] == location_value].index.tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'location_{location_col}_{location_value}',
                        'indices': indices,
                        'type': 'village'
                    })
        else:
            print("No location indicators found")

        print(f"Raw geographic groups created: {len(groups)}")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        print(f"Balanced geographic groups after filtering: {len(balanced_groups)}")

        return balanced_groups
    
        """Create groups by geographic/village characteristics."""
        print(f"Creating geographic groups")

        # Look for geographic identifiers
        geo_patterns = ['village', 'block', 'district', 'area', 'region', 'location', 'gram']
        geo_cols = []

        for pattern in geo_patterns:
            matching_cols = [col for col in df.columns if pattern in col.lower()]
            geo_cols.extend(matching_cols)

        geo_cols = list(set(geo_cols))  # Remove duplicates

        # Use the first available geographic variable
        geo_col = geo_cols[0]
        print(f"Using geographic variable: {geo_col}")

        groups = []
        for geo_id in df[geo_col].unique():
            if pd.isna(geo_id):
                continue

            indices = df[df[geo_col] == geo_id].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'geo_{geo_col}_{geo_id}',
                    'indices': indices,
                    'type': 'geographic'
                })

        print(f"Created {len(groups)} geographic groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_causal_forest_groups(self, df, n_groups=30, min_size=6):
        """Create groups using Random Forest to predict treatment effects."""
        print(f"Creating causal forest groups (target: {n_groups})")

        # Exclude outcome, treatment, and consumption change columns
        exclude_patterns = ['outcome', 'treatment', 'el1', 'el2', 'el3', 'el4', 'total_score']
        feature_cols = [col for col in df.columns
                       if not any(pattern in col for pattern in exclude_patterns)]

        X = df[feature_cols].copy()

        # Handle different data types properly
        for col in X.columns:
            if X[col].dtype == 'object' or X[col].dtype.name == 'category':
                try:
                    X[col] = LabelEncoder().fit_transform(X[col].astype(str))
                except:
                    X[col] = 0  # Default for problematic columns
            elif X[col].dtype == 'bool':
                X[col] = X[col].astype(int)

        # Fill missing values
        for col in X.columns:
            if X[col].isna().any():
                if X[col].dtype in ['int64', 'float64']:
                    X[col] = X[col].fillna(X[col].median())
                else:
                    X[col] = X[col].fillna(0)

        # Train separate models
        treated_mask = df['treatment'] == 1
        control_mask = df['treatment'] == 0

        if treated_mask.sum() < 5 or control_mask.sum() < 5:
            print("Not enough treated or control observations for causal forest")
            return []

        rf_treated = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)
        rf_control = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)

        rf_treated.fit(X[treated_mask], df.loc[treated_mask, 'outcome'])
        rf_control.fit(X[control_mask], df.loc[control_mask, 'outcome'])

        # Predict CATE and cluster
        pred_cate = rf_treated.predict(X) - rf_control.predict(X)
        cluster_features = np.column_stack([X.values, pred_cate.reshape(-1, 1)])
        cluster_features = StandardScaler().fit_transform(cluster_features)

        labels = KMeans(n_clusters=n_groups, random_state=self.random_seed).fit_predict(cluster_features)

        groups = []
        for i in range(n_groups):
            indices = df.index[labels == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'causal_forest_{i}',
                    'indices': indices,
                    'type': 'causal_forest'
                })

        print(f"Created {len(groups)} causal forest groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_propensity_groups(self, df, n_groups=50, min_size=6):
        """Create groups based on propensity score strata."""
        print(f"Creating propensity score groups (target: {n_groups})")

        # Exclude outcome and treatment columns
        feature_cols = [col for col in df.columns
                       if col not in ['treatment', 'outcome']]

        X = df[feature_cols].copy()

        # Handle different data types properly
        for col in X.columns:
            if X[col].dtype == 'object' or X[col].dtype.name == 'category':
                try:
                    X[col] = LabelEncoder().fit_transform(X[col].astype(str))
                except:
                    X[col] = 0
            elif X[col].dtype == 'bool':
                X[col] = X[col].astype(int)

        # Fill missing values
        for col in X.columns:
            if X[col].isna().any():
                if X[col].dtype in ['int64', 'float64']:
                    X[col] = X[col].fillna(X[col].median())
                else:
                    X[col] = X[col].fillna(0)

        # Get propensity scores
        try:
            prop_scores = cross_val_predict(
                LogisticRegression(random_state=self.random_seed, max_iter=1000),
                X, df['treatment'], method='predict_proba', cv=5
            )[:, 1]
        except Exception as e:
            print(f"Error computing propensity scores: {e}")
            return []

        # Create strata
        quantiles = np.linspace(0, 1, n_groups + 1)
        bins = np.digitize(prop_scores, np.quantile(prop_scores, quantiles)) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'propensity_{i}',
                    'indices': indices,
                    'type': 'propensity'
                })

        print(f"Created {len(groups)} propensity groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_asset_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on baseline asset ownership."""
        print(f"Creating asset-based groups (target: {n_groups})")

        # Look for asset-related variables
        asset_patterns = ['asset', 'livestock', 'land', 'house', 'own']
        asset_cols = []

        for pattern in asset_patterns:
            matching_cols = [col for col in df.columns
                           if pattern in col.lower() and 'bl' in col]  # Baseline assets
            asset_cols.extend(matching_cols)

        asset_cols = list(set(asset_cols))

        if not asset_cols:
            print("No asset variables found, using baseline consumption as proxy")
            if 'baseline_consumption' in df.columns:
                asset_score = df['baseline_consumption']
            else:
                print("No suitable asset or consumption variables found")
                return []
        else:
            print(f"Using asset variables: {asset_cols[:5]}")  # Show first 5

            # Create asset score from available variables - handle different data types
            asset_df = df[asset_cols].copy()
            processed_assets = []

            for col in asset_cols:
                col_data = asset_df[col]

                # Handle different data types
                if col_data.dtype == 'object' or col_data.dtype.name == 'category':
                    # For categorical/string columns, convert to binary (has asset vs doesn't)
                    # Assume non-null, non-empty, non-zero values indicate asset ownership
                    try:
                        # Try to convert to numeric first
                        numeric_version = pd.to_numeric(col_data, errors='coerce')
                        if not numeric_version.isna().all():
                            processed_assets.append(numeric_version.fillna(0))
                        else:
                            # Convert categorical to binary
                            binary_version = (~col_data.isna() &
                                            (col_data != '') &
                                            (col_data != '0') &
                                            (col_data.astype(str).str.upper() != 'NO')).astype(int)
                            processed_assets.append(binary_version)
                    except:
                        # Fallback: just check if not null
                        binary_version = (~col_data.isna()).astype(int)
                        processed_assets.append(binary_version)

                elif col_data.dtype == 'bool':
                    processed_assets.append(col_data.astype(int))

                else:
                    # Numeric columns - use as is, fill NaN with 0
                    processed_assets.append(col_data.fillna(0))

            if processed_assets:
                # Create asset score as sum of processed asset indicators
                asset_score = pd.concat(processed_assets, axis=1).sum(axis=1)
            else:
                print("Could not process any asset variables")
                return []

        # Create asset-based groups
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(asset_score, percentiles)
        bins = np.digitize(asset_score, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'asset_level_{i}',
                    'indices': indices,
                    'type': 'assets'
                })

        print(f"Created {len(groups)} asset groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_livelihood_groups(self, df, min_size=6):
        """Create groups based on baseline livelihood strategies."""
        print(f"Creating livelihood strategy groups")

        # Look for occupation/livelihood variables
        livelihood_patterns = ['occup', 'work', 'job', 'labor', 'employ', 'income_source']
        livelihood_cols = []

        for pattern in livelihood_patterns:
            matching_cols = [col for col in df.columns
                           if pattern in col.lower() and 'bl' in col]
            livelihood_cols.extend(matching_cols)

        livelihood_cols = list(set(livelihood_cols))

        groups = []
        for livelihood_type in df[livelihood_col].unique():
            if pd.isna(livelihood_type):
                continue

            indices = df[df[livelihood_col] == livelihood_type].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'livelihood_{livelihood_type}',
                    'indices': indices,
                    'type': 'livelihood'
                })

        print(f"Created {len(groups)} livelihood groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Ensure treatment balance and compute group CATE."""
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            cate = treated_outcomes.mean() - control_outcomes.mean()

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    def normalize_cates(self, groups):
        """Min-max scale each group's 'cate' into 'normalized_cate' in [0, 1].

        The group dicts are modified in place. When every CATE is identical
        the normalized value is set to 0.5 for all groups. An empty list is
        returned unchanged.

        Args:
            groups: List of group dicts, each with a 'cate' key.

        Returns:
            The same list, with 'normalized_cate' added to every group.
        """
        if not groups:
            # min()/max() of an empty sequence would raise; nothing to scale.
            print("CATE normalization skipped: no groups provided")
            return groups

        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)
        spread = max_cate - min_cate

        for group in groups:
            if max_cate > min_cate:
                group['normalized_cate'] = (group['cate'] - min_cate) / spread
            else:
                # Degenerate case: all CATEs equal, so place them mid-scale.
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.3f}, {max_cate:.3f}] → [0, 1]")
        return groups

    def plot_cate_distribution(self, groups, title_suffix=""):
        """Show side-by-side histograms of the original and normalized CATEs.

        Args:
            groups: Group dicts carrying both 'cate' and 'normalized_cate'.
            title_suffix: Text appended to each panel title.
        """
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # One entry per panel: (values, bar color, x-axis label, title stem)
        panels = [
            ([g['cate'] for g in groups], 'skyblue',
             'Original CATE', 'Original CATE Distribution'),
            ([g['normalized_cate'] for g in groups], 'lightcoral',
             'Normalized CATE (τ)', 'Normalized CATE Distribution'),
        ]

        for axis, (values, color, x_label, title_stem) in zip(axes, panels):
            axis.hist(values, bins=15, alpha=0.7, color=color, edgecolor='black')
            axis.set_xlabel(x_label)
            axis.set_ylabel('Frequency')
            axis.set_title(f'{title_stem}{title_suffix}')
            axis.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def estimate_tau(self, true_tau, accuracy):
        """Estimate tau using Hoeffding's inequality with Bernoulli samples."""
        sample_size = int(np.ceil(np.log(2/self.delta) / (2 * accuracy**2)))
        samples = np.random.binomial(1, true_tau, sample_size)
        return np.mean(samples), sample_size

    def run_single_trial(self, groups, epsilon_val, trial_seed):
        """Simulate one allocation trial for every budget K = 1 .. n_groups - 1.

        NumPy's global RNG is reseeded with ``random_seed + trial_seed``. Each
        group's normalized CATE is then estimated twice, first at accuracy
        ρ = γ·√ε for all groups and then at accuracy ε for all groups. For
        each budget K the top-K groups by estimate are compared with the true
        top-K, and the interval [τ_K, τ_K + 2ρ] around the smallest selected
        ρ-estimate is checked for being heavy.

        Returns:
            Tuple ``(results, tau_estimates_rho)``: one result dict per
            budget, plus the array of ρ-accuracy estimates.
        """
        np.random.seed(self.random_seed + trial_seed)

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])
        rho = self.gamma * np.sqrt(epsilon_val)  # fixed gamma

        def draw_estimates(accuracy):
            # One estimate per group, drawn in group order.
            return np.array([self.estimate_tau(tau, accuracy)[0] for tau in tau_true])

        # Order matters for the RNG stream: all rho draws precede all eps draws.
        tau_estimates_rho = draw_estimates(rho)
        tau_estimates_eps = draw_estimates(epsilon_val)

        # Rankings do not depend on K, so compute them once.
        rank_true = np.argsort(tau_true)
        rank_rho = np.argsort(tau_estimates_rho)
        rank_eps = np.argsort(tau_estimates_eps)

        success_threshold = 1 - epsilon_val
        expected_a2 = 2 * rho * n_groups
        heavy_cutoff = self.heavy_multiplier * expected_a2

        results = []
        for K in range(1, n_groups):
            optimal_value = np.sum(tau_true[rank_true[-K:]])

            rho_indices = rank_rho[-K:]
            rho_value = np.sum(tau_true[rho_indices])
            eps_value = np.sum(tau_true[rank_eps[-K:]])

            if optimal_value > 0:
                rho_ratio = rho_value / optimal_value
                eps_ratio = eps_value / optimal_value
            else:
                rho_ratio = 0
                eps_ratio = 0

            # Smallest estimate among the selected K groups anchors the interval.
            tau_k_est = tau_estimates_rho[rho_indices[0]]
            in_interval = (tau_estimates_rho >= tau_k_est) & (tau_estimates_rho <= tau_k_est + 2 * rho)
            units_in_a2 = np.sum(in_interval)

            results.append({
                'K': K,
                'optimal_value': optimal_value,
                'rho_value': rho_value,
                'eps_value': eps_value,
                'rho_ratio': rho_ratio,
                'eps_ratio': eps_ratio,
                'rho_success': rho_ratio >= success_threshold,
                'eps_success': eps_ratio >= success_threshold,
                'is_heavy': units_in_a2 > heavy_cutoff,
                'tau_k_est': tau_k_est,
                'units_in_a2': units_in_a2
            })

        return results, tau_estimates_rho

    def find_recovery_units(self, K, tau_true, tau_estimates, epsilon_val):
        """Find minimum units needed to achieve 1-epsilon performance."""
        n_groups = len(tau_true)

        # Original allocation (using rho estimates)
        rho_indices = np.argsort(tau_estimates)[-K:]
        optimal_value = np.sum(tau_true[np.argsort(tau_true)[-K:]])

        # Remaining candidates (sorted by estimate, best first)
        remaining_indices = np.argsort(tau_estimates)[:-K][::-1]

        # Test adding 1 to 10 additional units
        for extra in range(1, 11):
            if extra > len(remaining_indices):
                break

            expanded_indices = np.concatenate([rho_indices, remaining_indices[:extra]])
            expanded_value = np.sum(tau_true[expanded_indices])

            if expanded_value / optimal_value >= (1 - epsilon_val):
                return extra

        return None  # Need more than 10 units

    def find_closest_working_budget(self, failed_K, trial_results):
        """Measure how far a failed budget is from the nearest successful ones.

        Args:
            failed_K: The budget that failed.
            trial_results: Result dicts with 'K' and 'rho_success' keys.

        Returns:
            Tuple ``(distance_any, distance_smaller)``: the distance to the
            nearest successful budget in either direction, and the distance
            to the nearest successful budget below ``failed_K`` (None if
            there is none). ``(None, None)`` when no budget succeeded.
        """
        successes = [entry['K'] for entry in trial_results if entry['rho_success']]
        if not successes:
            return None, None

        # Nearest success in either direction
        nearest_any = min(abs(budget - failed_K) for budget in successes)

        # Nearest success strictly below the failed budget (underspending)
        below = [budget for budget in successes if budget < failed_K]
        nearest_below = failed_K - max(below) if below else None

        return nearest_any, nearest_below

    def analyze_method(self, groups, epsilon_val, n_trials=30):
        """Analyze single method with fixed gamma and updated heavy threshold."""
        print(f"\nAnalyzing {len(groups)} groups with ε={epsilon_val}, γ={self.gamma}")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        trial_data = []

        for trial in range(n_trials):
            print(f"Trial {trial + 1}/{n_trials}...")

            # Run single trial
            trial_results, tau_estimates = self.run_single_trial(groups, epsilon_val, trial)

            # Analyze failures
            failed_results = [r for r in trial_results if not r['rho_success']]
            failed_budgets = [r['K'] for r in failed_results]

            # Check which failed budgets are heavy with true tau values
            failed_heavy_estimated = []
            failed_heavy_true = []
            rho = self.gamma * np.sqrt(epsilon_val)

            for failed_result in failed_results:
                K = failed_result['K']
                # Heavy with estimated values (already computed)
                failed_heavy_estimated.append(failed_result['is_heavy'])

                # Check heavy with true tau values
                tau_k_true = tau_true[np.argsort(tau_true)[-K:]][0]  # True smallest in top-K
                a2_lower_true = tau_k_true
                a2_upper_true = tau_k_true + 2 * rho
                units_in_a2_true = np.sum((tau_true >= a2_lower_true) & (tau_true <= a2_upper_true))
                expected_a2_true = 2 * rho * n_groups
                is_heavy_true = units_in_a2_true > self.heavy_multiplier * expected_a2_true
                failed_heavy_true.append(is_heavy_true)

            # Print trial summary
            print(f"  Failed budgets: {failed_budgets}")

            # Print heavy vectors
            if len(failed_budgets) > 0:
                estimated_clean = [bool(x) for x in failed_heavy_estimated]
                true_clean = [bool(x) for x in failed_heavy_true]
                print(f"  HEAVY INTERVALS - Estimated: {estimated_clean}")
                print(f"  HEAVY INTERVALS - True τ_K:   {true_clean}")

            # Count total heavy intervals and failed budgets in heavy intervals
            total_heavy = sum(r['is_heavy'] for r in trial_results)
            failed_heavy = sum(r['is_heavy'] for r in failed_results)

            # Recovery analysis
            recovery_units = []
            distances_to_working_any = []
            distances_to_working_smaller = []

            for failed_result in failed_results:
                K = failed_result['K']

                # Find recovery units needed
                recovery = self.find_recovery_units(K, tau_true, tau_estimates, epsilon_val)
                if recovery is not None:
                    recovery_units.append(recovery)

                # Find distances to closest working budgets
                distance_any, distance_smaller = self.find_closest_working_budget(K, trial_results)
                if distance_any is not None:
                    distances_to_working_any.append(distance_any)
                if distance_smaller is not None:
                    distances_to_working_smaller.append(distance_smaller)

            trial_info = {
                'trial': trial,
                'failed_budgets': failed_budgets,
                'num_failures': len(failed_results),
                'total_heavy': total_heavy,
                'failed_heavy': failed_heavy,
                'failed_heavy_estimated': failed_heavy_estimated,
                'failed_heavy_true': failed_heavy_true,
                'recovery_units': recovery_units,
                'distances_to_working_any': distances_to_working_any,
                'distances_to_working_smaller': distances_to_working_smaller
            }

            trial_data.append(trial_info)

            print(f"  Failures: {len(failed_results)}, Total heavy: {total_heavy}, Failed heavy: {failed_heavy}")
            if recovery_units:
                print(f"  Recovery units: μ={np.mean(recovery_units):.1f}, med={np.median(recovery_units):.0f}, max={np.max(recovery_units)}")
            if distances_to_working_any:
                print(f"  Distance any: μ={np.mean(distances_to_working_any):.1f}, med={np.median(distances_to_working_any):.0f}, max={np.max(distances_to_working_any)}")
            if distances_to_working_smaller:
                print(f"  Distance smaller: μ={np.mean(distances_to_working_smaller):.1f}, med={np.median(distances_to_working_smaller):.0f}, max={np.max(distances_to_working_smaller)}")
            else:
                print(f"  Distance smaller: No smaller working budgets found")

        return trial_data

    def print_method_summary(self, method_name, trial_data, n_groups, epsilon_val):
        """Print a one-row summary table for a method and return headline stats.

        Args:
            method_name: Label printed in the table title.
            trial_data: List of per-trial dicts providing ``'num_failures'``,
                ``'total_heavy'``, ``'failed_heavy'``, ``'recovery_units'``,
                ``'distances_to_working_any'`` and
                ``'distances_to_working_smaller'``.
            n_groups: Number of groups; failure rates are expressed relative
                to ``n_groups - 1``.
            epsilon_val: Epsilon value printed in the table title.

        Returns:
            Dict with ``'avg_failures'``, ``'failure_rate_pct'``,
            ``'avg_recovery'`` (NaN when no recovery values were recorded) and
            ``'n_groups'``.
        """
        print(f"\n{'='*100}")
        print(f"SUMMARY - {method_name} - ε={epsilon_val} - {n_groups} GROUPS")
        print("="*100)
        print(f"{'Fail μ':<7} {'Fail σ':<7} {'FailR% μ':<9} {'FailR% σ':<9} {'TotHvy':<8} {'FailHvy':<9} {'Rec μ':<7} {'Rec med':<8} {'Rec max':<8} {'DAny μ':<8} {'DAny σ':<10} {'DAny max':<10} {'DSmall μ':<10} {'DSmall σ':<12} {'DSmall max':<12}")
        print("-"*120)

        # Per-trial counts
        all_failures = [t['num_failures'] for t in trial_data]
        all_total_heavy = [t['total_heavy'] for t in trial_data]
        all_failed_heavy = [t['failed_heavy'] for t in trial_data]

        # Per-failure measurements pooled across every trial
        all_recovery = [v for t in trial_data for v in t['recovery_units']]
        all_distances_any = [v for t in trial_data for v in t['distances_to_working_any']]
        all_distances_smaller = [v for t in trial_data for v in t['distances_to_working_smaller']]

        avg_failures = np.mean(all_failures)
        std_failures = np.std(all_failures)
        avg_failure_rate = avg_failures / (n_groups - 1) * 100
        std_failure_rate = std_failures / (n_groups - 1) * 100
        avg_total_heavy = np.mean(all_total_heavy)
        avg_failed_heavy = np.mean(all_failed_heavy)

        # Each statistics group falls back to NaN when nothing was recorded,
        # so the formatted row below still prints
        if all_recovery:
            recovery_mean = np.mean(all_recovery)
            recovery_med = np.median(all_recovery)
            recovery_max = np.max(all_recovery)
        else:
            recovery_mean = recovery_med = recovery_max = np.nan

        if all_distances_any:
            distance_any_mean = np.mean(all_distances_any)
            distance_any_std = np.std(all_distances_any)
            distance_any_max = np.max(all_distances_any)
        else:
            distance_any_mean = distance_any_std = distance_any_max = np.nan

        if all_distances_smaller:
            distance_smaller_mean = np.mean(all_distances_smaller)
            distance_smaller_std = np.std(all_distances_smaller)
            distance_smaller_max = np.max(all_distances_smaller)
        else:
            distance_smaller_mean = distance_smaller_std = distance_smaller_max = np.nan

        print(f"{avg_failures:<7.1f} {std_failures:<7.1f} {avg_failure_rate:<9.1f} {std_failure_rate:<9.1f} {avg_total_heavy:<8.1f} {avg_failed_heavy:<9.1f} "
              f"{recovery_mean:<7.1f} {recovery_med:<8.0f} {recovery_max:<8.0f} "
              f"{distance_any_mean:<8.1f} {distance_any_std:<10.1f} {distance_any_max:<10.0f} "
              f"{distance_smaller_mean:<10.1f} {distance_smaller_std:<12.1f} {distance_smaller_max:<12.0f}")

        return {
            'avg_failures': avg_failures,
            'failure_rate_pct': avg_failure_rate,
            'avg_recovery': recovery_mean,
            'n_groups': n_groups
        }


def run_comprehensive_tup_analysis(df_tup, epsilon_values=None, n_trials=30, log_file=None,
                                   outcome_col=None, gamma=0.5, heavy_multiplier=1.6):
    """Run every TUP grouping method over a sweep of epsilon values.

    For each (method, epsilon) pair a fresh ``TUPCATEAllocator`` is built, the
    data is processed, groups are created and normalized, the CATE
    distribution is plotted, and ``analyze_method`` / ``print_method_summary``
    are run. Summary tables are printed per method, across all methods, and
    per epsilon. While the function runs, ``sys.stdout`` is replaced by a
    ``TeeOutput`` so everything printed also lands in ``log_file``; stdout is
    restored and the log closed on exit, even on error.

    Args:
        df_tup: TUP dataset handed to ``TUPCATEAllocator.process_tup_data``.
        epsilon_values: Epsilon values to sweep; defaults to
            ``[0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]``.
        n_trials: Trials per (method, epsilon) pair.
        log_file: Path of the log file; a timestamped name is generated when
            omitted.
        outcome_col: Forwarded to ``process_tup_data``.
        gamma: Gamma passed to every allocator and used for the reported
            rho = gamma * sqrt(epsilon). Defaults to the previously
            hard-coded 0.5.
        heavy_multiplier: Heavy-interval multiplier passed to every
            allocator. Defaults to the previously hard-coded 1.6.

    Returns:
        ``(all_results, summary_data)`` where ``all_results`` maps each method
        name to its list of per-epsilon result dicts and ``summary_data`` is a
        flat list of per-(method, epsilon) summary rows.
    """
    if epsilon_values is None:
        epsilon_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    if log_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # 0.5 -> "05", which reproduces the historical "gamma05" file name
        gamma_tag = str(gamma).replace('.', '')
        log_file = f"tup_comprehensive_analysis_gamma{gamma_tag}_{timestamp}.txt"

    # Column headers shared by the per-method tables
    table_header = f"{'ε':<8} {'√ε':<10} {'γ':<6} {'ρ':<10} {'Groups':<8} {'Fail μ':<8} {'FailR%':<8} {'Rec μ':<8}"

    # Redirect output to both console and file
    original_stdout = sys.stdout
    tee = TeeOutput(log_file)
    sys.stdout = tee

    try:
        print(f"COMPREHENSIVE TUP ANALYSIS - ALL METHODS, FIXED γ={gamma}, HEAVY THRESHOLD={heavy_multiplier}x")
        print(f"Log file: {log_file}")
        print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("="*100)

        # Define all TUP-specific grouping methods
        methods = [
            ('Village', lambda allocator, df: allocator.create_village_groups(df, min_size=6)),
            ('Baseline Poverty', lambda allocator, df: allocator.create_baseline_poverty_groups(df, n_groups=30, min_size=6)),
            ('Demographics', lambda allocator, df: allocator.create_demographics_groups(df, min_size=6)),
            ('Geographic', lambda allocator, df: allocator.create_geographic_groups(df, min_size=6)),
            ('Assets', lambda allocator, df: allocator.create_asset_groups(df, n_groups=30, min_size=6)),
            ('Livelihood', lambda allocator, df: allocator.create_livelihood_groups(df, min_size=6)),
            ('Causal Forest 30', lambda allocator, df: allocator.create_causal_forest_groups(df, n_groups=30, min_size=6)),
            ('Causal Forest 50', lambda allocator, df: allocator.create_causal_forest_groups(df, n_groups=50, min_size=6)),
            ('Propensity Score', lambda allocator, df: allocator.create_propensity_groups(df, n_groups=50, min_size=6))
        ]

        all_results = {}

        for method_name, method_func in methods:
            print(f"\n{'='*120}")
            print(f"ANALYZING TUP METHOD: {method_name}")
            print("="*120)

            method_results = []

            for eps in epsilon_values:
                print(f"\n{'='*100}")
                print(f"METHOD: {method_name} | EPSILON = {eps}")
                print("="*100)

                # Fresh allocator per epsilon, using the requested gamma and heavy threshold
                allocator = TUPCATEAllocator(epsilon=eps, gamma=gamma, heavy_multiplier=heavy_multiplier)
                df_processed = allocator.process_tup_data(df_tup, outcome_col=outcome_col)

                try:
                    # Create groups using this method
                    groups = method_func(allocator, df_processed)

                    if len(groups) < 3:
                        print(f"Too few groups ({len(groups)}) for {method_name} with ε = {eps} - skipping")
                        continue

                    groups = allocator.normalize_cates(groups)

                    # Show CATE distribution
                    allocator.plot_cate_distribution(groups, f" ({method_name}, ε={eps})")

                    # Run analysis for this epsilon and method
                    trial_data = allocator.analyze_method(groups, eps, n_trials)

                    # Print method summary
                    stats = allocator.print_method_summary(method_name, trial_data, len(groups), eps)

                    method_results.append({
                        'method': method_name,
                        'epsilon': eps,
                        'sqrt_epsilon': np.sqrt(eps),
                        'gamma': gamma,
                        'rho': gamma * np.sqrt(eps),
                        'groups': groups,
                        'trial_data': trial_data,
                        'stats': stats
                    })

                except Exception as e:
                    # Deliberate best-effort: one failing grouping/epsilon
                    # combination must not abort the rest of the sweep
                    print(f"Error with {method_name} at ε = {eps}: {e}")
                    continue

            all_results[method_name] = method_results

            # Add method-specific summary table after all epsilons for this method
            if method_results:
                print(f"\n{'='*120}")
                print(f"METHOD SUMMARY - {method_name} - ALL EPSILON VALUES")
                print("="*120)
                print(table_header)
                print("-" * 80)

                for eps_result in method_results:
                    stats = eps_result['stats']
                    print(f"{eps_result['epsilon']:<8} {eps_result['sqrt_epsilon']:<10.6f} {eps_result['gamma']:<6} {eps_result['rho']:<10.6f} "
                          f"{len(eps_result['groups']):<8} {stats['avg_failures']:<8.1f} {stats['failure_rate_pct']:<8.1f} "
                          f"{stats['avg_recovery']:<8.1f}")
                print("="*120)

        # Create comprehensive summary across all methods and epsilon values
        print(f"\n{'='*200}")
        print("COMPREHENSIVE SUMMARY - ALL TUP METHODS AND EPSILON VALUES")
        print("="*200)

        # Flat list of summary rows, one per (method, epsilon) that produced results
        summary_data = []

        for method_name, method_results in all_results.items():
            if not method_results:
                continue

            print(f"\n{'-'*100}")
            print(f"TUP METHOD: {method_name}")
            print("-"*100)

            for eps_result in method_results:
                stats = eps_result['stats']
                summary_data.append({
                    'method': method_name,
                    'epsilon': eps_result['epsilon'],
                    'sqrt_eps': eps_result['sqrt_epsilon'],
                    'gamma': eps_result['gamma'],
                    'rho': eps_result['rho'],
                    'avg_failures': stats['avg_failures'],
                    'failure_rate_pct': stats['failure_rate_pct'],
                    'avg_recovery': stats['avg_recovery'],
                    'n_groups': stats['n_groups']
                })

            # Print method-specific table
            method_data = [d for d in summary_data if d['method'] == method_name]
            if method_data:
                print(table_header)
                print("-" * 80)

                for data in method_data:
                    print(f"{data['epsilon']:<8} {data['sqrt_eps']:<10.6f} {data['gamma']:<6} {data['rho']:<10.6f} "
                          f"{data['n_groups']:<8} {data['avg_failures']:<8.1f} {data['failure_rate_pct']:<8.1f} "
                          f"{data['avg_recovery']:<8.1f}")

        # Overall summary table
        print(f"\n{'='*200}")
        print("OVERALL SUMMARY TABLE - ALL TUP METHODS COMBINED")
        print("="*200)
        print(f"{'Method':<18} {table_header}")
        print("-" * 100)

        for data in summary_data:
            print(f"{data['method']:<18} {data['epsilon']:<8} {data['sqrt_eps']:<10.6f} {data['gamma']:<6} {data['rho']:<10.6f} "
                  f"{data['n_groups']:<8} {data['avg_failures']:<8.1f} {data['failure_rate_pct']:<8.1f} "
                  f"{data['avg_recovery']:<8.1f}")

        # TUP-specific analysis insights
        print(f"\n{'='*100}")
        print("KEY INSIGHTS FOR TUP DATASET")
        print("="*100)

        # Find best and worst performing methods for TUP
        if summary_data:
            # Average failure rate across all epsilon values, per method
            method_performance = {}
            for method_name in all_results:
                method_data = [d for d in summary_data if d['method'] == method_name]
                if method_data:
                    method_performance[method_name] = np.mean([d['failure_rate_pct'] for d in method_data])

            if method_performance:
                best_method = min(method_performance, key=method_performance.get)
                worst_method = max(method_performance, key=method_performance.get)

                print(f"BEST PERFORMING TUP METHOD: {best_method}")
                print(f"  Average failure rate: {method_performance[best_method]:.1f}%")

                print(f"\nWORST PERFORMING TUP METHOD: {worst_method}")
                print(f"  Average failure rate: {method_performance[worst_method]:.1f}%")

                print(f"\nTUP METHOD RANKING (by average failure rate):")
                sorted_methods = sorted(method_performance.items(), key=lambda x: x[1])
                for i, (method, rate) in enumerate(sorted_methods, 1):
                    print(f"  {i}. {method}: {rate:.1f}%")

        # Effect of epsilon on TUP: average failure rate across methods, per epsilon
        print(f"\nEFFECT OF EPSILON ON TUP:")
        epsilon_performance = {}
        for eps in epsilon_values:
            eps_data = [d for d in summary_data if d['epsilon'] == eps]
            if eps_data:
                epsilon_performance[eps] = np.mean([d['failure_rate_pct'] for d in eps_data])

        if epsilon_performance:
            rho_label = f"ρ = {gamma}√ε"
            print(f"{'Epsilon':<10} {'Avg Failure Rate':<15} {rho_label:<12}")
            print("-" * 40)
            for eps in sorted(epsilon_performance.keys()):
                rho = gamma * np.sqrt(eps)
                print(f"{eps:<10} {epsilon_performance[eps]:<15.1f} {rho:<12.6f}")

        return all_results, summary_data

    finally:
        # Restore original stdout and close log file
        sys.stdout = original_stdout
        tee.close()


# Example usage for TUP dataset
if __name__ == "__main__":
    # Load TUP dataset (using the same preprocessing as in the original code)
    import pandas as pd
    from sklearn.ensemble import RandomForestRegressor

    def preprocess_tup_data(filepath):
        """Load the TUP Stata file and build the modelling frame.

        Steps: read the file, exclude object-dtype columns and endline
        columns (names ending in ``el1``..``el4`` or starting with ``el``),
        one-hot encode what remains, add the consumption change
        (``pc_exp_month_el3 - pc_exp_month_bl``) as the target, drop rows
        without a target and columns that are entirely missing, mean-impute
        the rest, and keep the 1000 most important features according to a
        random forest, plus ``treatment`` and ``pc_exp_month_bl``.

        Args:
            filepath: Path to the Stata file.

        Returns:
            DataFrame with the selected features plus
            ``'target_column_consumption'`` and an identical ``'outcome'``
            column.
        """
        # Load the data
        df1 = pd.read_stata(filepath)
        print(f"Loaded TUP data: {df1.shape}")

        # Follow the original preprocessing steps
        columns_to_drop_original = df1.columns[df1.columns.str.endswith('el1') |
                                      df1.columns.str.endswith('el2') |
                                      df1.columns.str.endswith('el3') |
                                      df1.columns.str.endswith('el4') | df1.columns.str.startswith('el')]

        categorical_columns = df1.select_dtypes(include=['object']).columns
        unique_value_counts = df1[categorical_columns].nunique()

        # Object columns with far more distinct values than average
        cardinality_cutoff = 2 * unique_value_counts.mean()
        leave_categories = [col for col in categorical_columns
                            if unique_value_counts[col] > cardinality_cutoff]

        # Keep only relevant columns
        # NOTE: every object column is already excluded by the first test, so
        # the leave_categories test is redundant; kept to mirror the original steps.
        object_columns = df1.select_dtypes(include=['object'])
        cols_keep = [col for col in df1.columns
                     if col not in object_columns
                     and col not in leave_categories
                     and col not in columns_to_drop_original]

        df1_filtered = df1[cols_keep]
        df1_encoded = pd.get_dummies(df1_filtered)

        # Create target variable
        target_col_consumption = df1['pc_exp_month_el3'] - df1['pc_exp_month_bl']
        df1_encoded['target_column_consumption'] = target_col_consumption

        # Clean data
        df1_encoded = df1_encoded[df1_encoded['target_column_consumption'].notna()]
        df1_encoded = df1_encoded.dropna(axis=1, how='all')
        df1_encoded = df1_encoded.fillna(df1_encoded.mean())

        # Feature selection using Random Forest
        X = df1_encoded.drop(columns='target_column_consumption')
        y = df1_encoded['target_column_consumption']

        rf = RandomForestRegressor(n_estimators=100, random_state=42)
        rf.fit(X, y)

        feature_importances = rf.feature_importances_
        importance_df = pd.DataFrame({'Feature': X.columns, 'Importance': feature_importances})
        top_1000_features = importance_df.sort_values(by='Importance', ascending=False).head(1000)['Feature'].tolist()

        # Ensure treatment and baseline consumption are included
        if 'treatment' not in top_1000_features:
            top_1000_features.append('treatment')
        if 'pc_exp_month_bl' not in top_1000_features:
            top_1000_features.append('pc_exp_month_bl')

        df_final = df1_encoded[top_1000_features].copy()

        df_final['target_column_consumption'] = df1_encoded['target_column_consumption']
        df_final['outcome'] = df1_encoded['target_column_consumption']  # Also create 'outcome' column

        return df_final

    # The dataset location is machine-specific, so it is taken from the command
    # line; previously df_tup was never assigned and the script raised NameError.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: python {sys.argv[0]} <path-to-TUP-stata-file>")

    df_tup = preprocess_tup_data(sys.argv[1])

    epsilon_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    results, summary = run_comprehensive_tup_analysis(
        df_tup,
        epsilon_values=epsilon_values,
        n_trials=30,
        log_file="tup_comprehensive_analysis_gamma05_heavy16.txt"
    )